import argparse
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
import aiohttp 

class SimpleGSM8KSolver:
    """Simplified GSM8K solver with direct IO mode.

    Each problem is sent to a chat-completions endpoint in a single few-shot
    prompt; the numeric answer is then extracted from the reply and can be
    compared against the reference answer extracted from the dataset.
    """

    # One number: an optional minus sign (only when it is not glued to a
    # preceding word character, "." or ")", so "10-5" is not read as "-5"),
    # an integer part with or without thousands separators, and an optional
    # decimal fraction. A bare fraction such as ".5" is accepted as well.
    _NUMBER_RE = re.compile(
        r'(?:(?<![\w.)])-)?'
        r'(?:(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?|(?<!\d)\.\d+)'
    )

    def __init__(self):
        # Placeholders: set these to the served model name and the API root
        # (the request is posted to f"{base_url}/chat/completions").
        self.model = "your model"
        self.base_url = "your base_url"
        # [estimated input tokens, estimated output tokens], accumulated by
        # generate(); callers may reset it between problems.
        self.token_counts = [0, 0]
        # Running evaluation statistics maintained by update_stats().
        self.stats = {
            "total_problems": 0,
            "correct_answers": 0,
            "incorrect_answers": 0,
            "accuracy": 0.0
        }

    async def generate(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Token usage is only estimated (character count // 4) and added to
        ``self.token_counts``.

        Raises:
            Exception: any transport, HTTP-status or response-shape error is
                printed with an "LLM Error" prefix and re-raised.
        """
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
            "max_tokens": 8000,
            "top_p": 0.8
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.base_url}/chat/completions",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=120)
                ) as response:
                    # Fail with the real HTTP error instead of a confusing
                    # KeyError on "choices" when the server rejects the call.
                    response.raise_for_status()
                    resp = await response.json()

            # The content field can be null; treat that as an empty reply so
            # the length-based estimate below does not raise TypeError.
            content = resp["choices"][0]["message"]["content"] or ""

            self.token_counts[0] += len(prompt) // 4
            self.token_counts[1] += len(content) // 4
            return content
        except Exception as e:
            print(f"LLM Error: {str(e)}")
            raise

    @classmethod
    def _normalize_number(cls, raw: str, *, last: bool = False) -> Optional[str]:
        """Return the first (or last) number found in *raw* in canonical form.

        Canonical form has no thousands separators, no leading zeros and no
        trailing fractional zeros, so "1,234.50" -> "1234.5", "18.00" -> "18"
        and ".5" -> "0.5". The sign is kept ("-5" stays "-5", "-0" -> "0").
        Returns None when *raw* contains no number.
        """
        tokens = cls._NUMBER_RE.findall(raw)
        if not tokens:
            return None
        token = tokens[-1] if last else tokens[0]

        negative = token.startswith("-")
        whole, _, fraction = token.lstrip("-").replace(",", "").partition(".")
        whole = whole.lstrip("0") or "0"
        fraction = fraction.rstrip("0")

        canonical = f"{whole}.{fraction}" if fraction else whole
        if negative and canonical != "0":
            canonical = "-" + canonical
        return canonical

    def _extract_answer(self, text: str) -> Optional[str]:
        """Extract the model's numeric answer from its response text.

        Sources are tried in order: the last ``\\boxed{...}`` expression, the
        first "Final Answer: ..." line, then the last number anywhere in the
        text. Within a boxed expression or a final-answer line the first
        number is used, because the prompt asks for the answer right after
        the label. The result is canonicalised by ``_normalize_number`` so
        decimals and signs survive and unrelated digits are not merged in.

        Returns:
            The canonical number as a string, or None when nothing numeric
            could be found.
        """
        if not text:
            return None

        boxed_matches = re.findall(r'\\boxed\{([^{}]+)\}', text)
        if boxed_matches:
            return self._normalize_number(boxed_matches[-1])

        final_answer_match = re.search(
            r'Final\s+Answer\s*:\s*([^\n]+)',
            text,
            re.IGNORECASE
        )
        if final_answer_match:
            return self._normalize_number(final_answer_match.group(1))

        # Last resort: whatever number appears last in the response.
        return self._normalize_number(text, last=True)

    async def solve_problem(self, question: str) -> Dict[str, Any]:
        """Directly solve a math problem with one few-shot prompt.

        Returns:
            A dict with the raw ``response``, the extracted ``answer`` (or
            None) and a snapshot of ``self.token_counts`` under ``tokens``.
        """
        prompt = f"""
Let's think step by step, provide the final answer in the format "Final Answer: [your answer]".

Question: Farmer Brown has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?
Thought 1: Let C be the number of chickens, then (20 - C) is the number of cows.
Thought 2: Chickens have 2 legs, cows have 4 legs.
Thought 3: Total legs is 2C + 4(20 - C) = 70.
Calculation: 2C + 80 - 4C = 70 → -2C + 80 = 70 → 2C = 10 → C = 5
Final Answer: 5

Question: Henry and 3 of his friends order 7 pizzas for lunch. Each pizza is cut into 8 slices. If Henry and his friends want to share the pizzas equally, how many slices can each of them have?
Thought 1: Each pizza has 8 slices, and there are 7 pizzas.
Calculation: 7 × 8 = 56
Thought 2: There are 4 people in total (Henry + 3 friends).
Calculation: 56 ÷ 4 = 14
Final Answer: 14

Question:{question}
"""

        response = await self.generate(prompt)
        answer = self._extract_answer(response)

        return {
            "response": response,
            "answer": answer,
            "tokens": self.token_counts.copy()
        }

    async def load_problems(self, dataset_path: str, start_idx: int, end_idx: int) -> List[Dict]:
        """Load the problems ``[start_idx:end_idx]`` from a JSON array file.

        Loading is best-effort: if the file cannot be read, is not valid
        JSON, or does not hold a list, the error is printed and an empty
        list is returned.
        """
        try:
            with open(dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, list):
                raise TypeError("dataset must be a JSON array of problems")
            return data[start_idx:end_idx]
        except (OSError, ValueError, TypeError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading dataset: {str(e)}")
            return []

    def _extract_correct_answer(self, solution: str) -> Optional[str]:
        """Extract the reference answer from a dataset solution.

        Supported forms, tried in order: ``#### answer`` (last occurrence),
        ``\\boxed{answer}`` (last occurrence), "Final Answer: answer",
        "The answer is answer", and finally a solution that consists of
        nothing but a number. A numeric answer is returned in the same
        canonical form that ``_extract_answer`` produces, so the two can be
        compared as strings; a non-numeric answer is returned stripped.

        Returns:
            The answer string, or None when *solution* is empty or matches
            none of the supported forms.
        """
        if solution is None:
            return None
        # Tolerate datasets that store the reference as a bare int/float.
        solution = str(solution).strip()
        if not solution:
            return None

        raw = None

        # Pattern 1: #### answer (GSM8K original format)
        hash_matches = re.findall(r'####\s*([^\n]+)', solution)
        # Pattern 2: \boxed{answer} (common LaTeX format)
        boxed_matches = re.findall(r'\\boxed\{([^{}]+)\}', solution)
        if hash_matches:
            raw = hash_matches[-1]
        elif boxed_matches:
            raw = boxed_matches[-1]
        else:
            # Pattern 3: Final Answer: answer / Pattern 4: The answer is answer
            labelled_patterns = (
                r'Final\s+Answer\s*:\s*([^\n]+)',
                r'The\s+answer\s+is\s+([^\n]+)',
            )
            for pattern in labelled_patterns:
                match = re.search(pattern, solution, re.IGNORECASE)
                if match:
                    raw = match.group(1)
                    break

        if raw is None:
            # Pattern 5: the whole solution is just the number itself.
            if self._NUMBER_RE.fullmatch(solution):
                return self._normalize_number(solution)
            return None

        raw = raw.strip()
        return self._normalize_number(raw) or raw

    def update_stats(self, is_correct: bool):
        """Record one graded problem and refresh the accuracy percentage."""
        self.stats["total_problems"] += 1
        if is_correct:
            self.stats["correct_answers"] += 1
        else:
            self.stats["incorrect_answers"] += 1

        # total_problems is at least 1 here, so the division is safe.
        self.stats["accuracy"] = (
            self.stats["correct_answers"] / self.stats["total_problems"] * 100
        )

def _answers_match(predicted: Any, expected: Any) -> bool:
    """Return True when two answer strings denote the same value.

    Exact string equality is tried first. Otherwise both sides are parsed as
    numbers after dropping thousands separators, "$"/"%" signs and a trailing
    period, so "1,000" matches "1000" and "18." matches "18". Values that do
    not parse as numbers only match when the strings are identical.
    """
    left, right = str(predicted).strip(), str(expected).strip()
    if left == right:
        return True

    def to_number(text: str) -> float:
        cleaned = text.replace(",", "").replace("$", "").replace("%", "")
        return float(cleaned.strip().rstrip("."))

    try:
        return to_number(left) == to_number(right)
    except ValueError:
        return False


async def main():
    """Run the solver over a slice of the dataset and save graded results.

    Command-line options: ``--start`` / ``--end`` select the dataset slice
    and ``--dataset`` is the path of the JSON file. Results and statistics
    are written to ``log/gsm8k/`` when at least one problem was processed.
    """
    parser = argparse.ArgumentParser(description="Simple GSM8K Solver")
    parser.add_argument("--start", type=int, default=0, help="Start index in dataset")
    parser.add_argument("--end", type=int, default=1, help="End index in dataset")
    parser.add_argument("--dataset", type=str, default="gsm8k.json", help="Path to dataset")
    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs("log/gsm8k", exist_ok=True)

    solver = SimpleGSM8KSolver()
    problems = await solver.load_problems(args.dataset, args.start, args.end)
    results = []

    for idx, problem in enumerate(problems, args.start):
        if "question" not in problem:
            print(f"\n{'='*50}\nSkipping problem {idx}: No 'question' field\n{'='*50}")
            continue

        print(f"\n{'='*50}\nProcessing problem {idx}: {problem['question'][:50]}...\n{'='*50}")

        # Reset token counts for each problem
        solver.token_counts = [0, 0]

        result = await solver.solve_problem(problem["question"])

        # Prepare verification
        correct_answer = solver._extract_correct_answer(problem.get("answer", ""))
        is_correct = False

        if correct_answer:
            if result["answer"]:
                is_correct = _answers_match(result["answer"], correct_answer)
            # A problem with a reference answer is always graded: a response
            # with no extractable answer counts as incorrect rather than
            # silently dropping out of the accuracy denominator.
            solver.update_stats(is_correct)

        # Prepare result record
        results.append({
            "problem_id": idx,
            "question": problem["question"],
            "response": result["response"],
            "answer": result["answer"],
            "correct_answer": correct_answer,
            "is_correct": is_correct,
            "tokens": result["tokens"]
        })

        print(f"Answer: {result['answer']}")
        print(f"Correct answer: {correct_answer}")
        print(f"Verification: {'CORRECT' if is_correct else 'INCORRECT'}")

    # Save results
    if results:
        filename = f"log/gsm8k/results_{args.start}_{args.end}_acc{solver.stats['accuracy']:.2f}%.json"

        output = {
            "results": results,
            "statistics": solver.stats
        }

        with open(filename, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2, ensure_ascii=False)

        print(f"\n{'='*50}\nFinal Statistics\n{'='*50}")
        print(f"Results saved to {filename}")
        print(f"Total problems processed: {solver.stats['total_problems']}")
        print(f"Correct answers: {solver.stats['correct_answers']}")
        print(f"Incorrect answers: {solver.stats['incorrect_answers']}")
        print(f"Overall accuracy: {solver.stats['accuracy']:.2f}%")
        print(f"{'='*50}\n")

# Script entry point: run the async driver only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())